-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Support dspy.Tool
as input field type and dspy.ToolCall
as output field type
#8242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
a19db3b
to
c06b65c
Compare
@okhat @TomeHirata I am going to add some adapter change to this PR, will ping you for review after that is done. |
c06b65c
to
b171e9f
Compare
@@ -181,6 +200,44 @@ def __str__(self): | |||
return f"{self.name}{desc} {arg_desc}" | |||
|
|||
|
|||
class ToolCalls(BaseType): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
q: Why do we make ToolCalls
a first citizen rather than ToolCall
? Other frameworks basically define ToolCall first (langchain) and then treat tool call response as list[ToolCall]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you spotted the weirdo, and this is a design after multiple edits. I started with dspy.Tool
, but then realized we are leaving the output field in free forms. For example, users might do:
class QAWithMultipleToolCall(dspy.Signature):
question: str = dspy.InputField()
tools: list[dspy.Tool] = dspy.InputField()
answer: str = dspy.OutputField()
tool_call_1: dspy.ToolCall = dspy.OutputField()
tool_call_2: dspy.ToolCall = dspy.OutputField()
while we expect them to do:
class QAWithToolCall(dspy.Signature):
question: str = dspy.InputField()
tools: list[dspy.Tool] = dspy.InputField()
answer: str = dspy.OutputField()
tool_calls: list[dspy.ToolCall] = dspy.OutputField()
The bad thing about QAWithMultipleToolCall
is when we use the native function calling (see JSONAdapter in this PR), we have no clue how to write the tool calls back to the output field. With a regulation that the output field must be dspy.ToolCalls
, in native function calling use case, we can locate the output field and write the tool call info there.
For some more context, there is a caveat that OAI models are bad at outputting dict, so the current JSONAdapter is broken with tool calling. For a quick test, you can try using JSONAdapter + ReAct, in my testing that doesn't really work. Native function calling resolves this issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed response. I agree this is tricky. We could have a validation logic in dspy.Signature
and disallow the usage of a single dspy.ToolCall
field. Then, when we use the native function calling, how do we deal with the multiple tool call case? To my understanding, the native function calling does not provide any semantic grouping.
class QAWithToolCall(dspy.Signature):
question: str = dspy.InputField()
tools: list[dspy.Tool] = dspy.InputField()
answer: str = dspy.OutputField()
tier1_tool_calls: list[dspy.ToolCall] = dspy.OutputField()
tier2_tool_calls: list[dspy.ToolCall] = dspy.OutputField()
value = self.parse(signature, text) | ||
else: | ||
value = {} | ||
for field_name in signature.output_fields.keys(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chenmoneygithub qq - do we need to handle this tool-call specific logic here? I'm wondering that since it inherits from BaseType if we can add a parse
function to Tool (and the BaseType interface) and port over all the logic there (similar to the format
we use for custom types)? This could keep post-process generalizable, and we'd just add a check for if the signature includes a BaseType output field and parse accordingly. curious on thoughts here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Output handling is a bit different from input handling, which we can possibly generalize.
However I would be cautious about doing it because we don't know yet if generalization makes sense here - for the ToolCalls
case, we are reading from the LM response and write back to the output field that of type dspy.ToolCalls
if native function calling is used, which may be completely different from the second output field we introduce.
@@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None: | |||
cls.format = with_callbacks(cls.format) | |||
cls.parse = with_callbacks(cls.parse) | |||
|
|||
def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]: | |||
def _call_preprocess( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar q for the preprocess too, could this instead be wrapped in Tool's format
and make use of split_message_content_for_custom_types
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for the tool calling, here we need to handle the special case of native function calling: https://platform.openai.com/docs/guides/function-calling?api-mode=chat, so we need to modify the LM call args in addition to the messages.
return None | ||
|
||
def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool: | ||
for name, field in signature.output_fields.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when there are multiple fields with type ToolCalls
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then only one field will be populated with value. We can raise a warning when there are multiple ToolCalls field, I kinda doubt if users will do that though. We made it ToolCalls
as an indicator that it has multiple ToolCall
s.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
had the same question for multimodal. in a single call (diff from n
), i don't believe the models can produce multiple multimodal outputs with 2+ dspy.OutputFields (it'll raise the output exception we see sometimes)
could be different for ToolCalls tho, but might be safer to give that warning globally (maybe for the select ones in types?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with Arnav. For cases where users define multiple output fields with a type that can only have a single field, such as tool calls or multi-modal, can we raise an exception when the invalid signature is created rather than warning since this is a wrong usage of signature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chenmoneygithub Since this PR is merged, can you follow up with another PR on this?
from dspy.primitives.program import Module | ||
from dspy.primitives.tool import Tool |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
q: do we have isort running on your local? Shall we include in our CI?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea I have isort locally, and I remember there are a few files we don't want to sort, we can check after merging #7885
The main goal is making
dspy.Tool
a valid input type with special format so that users can easily usedspy.Tool
without usingdspy.ReAct
.dspy.ToolCall
is just a thing wrapper over tool name and tool args to simplify the output field definition involving tool calls.Sample usage:
With sample output (not including history):